#! /usr/bin/env python

import math

import rospy
import tf
from nav_msgs.msg import OccupancyGrid
from numpy import ones, sum
from std_msgs.msg import Float32, Header
from std_srvs.srv import Trigger

# Named indices for the components of position (x, y, z) and quaternion (x, y, z, w) tuples,
# so that lookups read as pos[X] instead of pos[0]
X, Y, Z, W = range(4)


class CoverageProgressNode(object):
    """The CoverageProgressNode keeps track of coverage progress.

    It does this by periodically looking up the position of the coverage disk in an occupancy grid.
    Cells within a radius from this position are 'covered'.

    Cell values are interpreted in this way: lower is more covered, higher is less covered
    - 100: uncovered (initial value)
    - < 100: covered

    The node emits a coverage progress, which is the ratio of cells considered covered
    (0.0 is everything uncovered, 1.0 is all covered).
    """

    DIRTY = 100  # Value of a cell that has not been covered yet

    def __init__(self):
        """Read the node parameters, create the coverage grid and start the periodic update timer."""
        self.listener = tf.TransformListener()

        # The attributes below are all assigned by _initialize_map()
        self._coverage_area = None  # type: Tuple[float, float]

        self.coverage_resolution = None  # type: float  # How big is a cell [m]
        self.coverage_radius_meters = None  # type: float
        self.coverage_radius_cells = None  # type: int

        # How well covered is a cell after it has been covered for 1 time step
        self.coverage_effectivity = None  # type: int

        self.map_frame = None  # type: str
        self.coverage_frame = None  # type: str

        self.grid = self._initialize_map()

        self.progress_pub = rospy.Publisher("coverage_progress", Float32, queue_size=1)
        self.grid_pub = rospy.Publisher("coverage_grid", OccupancyGrid, queue_size=1)

        self.reset_srv = rospy.Service('reset', Trigger, self.reset)

        self._rate = rospy.get_param("~rate", 10.0)
        self._update_timer = rospy.Timer(rospy.Duration(1.0/self._rate), self._update_callback)

    def _initialize_map(self):
        """Read the coverage parameters and build a fresh grid in which every cell is DIRTY.

        As a side effect this (re)assigns the parameter-derived attributes of the node
        (coverage area, radius, resolution, effectivity and frame names).

        :return: OccupancyGrid spanning the target area, all cells set to DIRTY
        :raises ValueError: when neither ~coverage_radius nor ~coverage_width is specified
        """
        # Define parameters
        x = rospy.get_param("~target_area/x", 10.0)  # height of the area to cover, in x direction of the map
        y = rospy.get_param("~target_area/y", 5.0)  # width of the area to cover, in y direction of the map
        self._coverage_area = (x, y)

        try:
            self.coverage_radius_meters = rospy.get_param("~coverage_radius")
        except KeyError:
            try:
                self.coverage_radius_meters = rospy.get_param("~coverage_width") / 2.0
            except KeyError:
                rospy.logerr("Specify either coverage_width or coverage_radius")
                raise ValueError("Neither ~coverage_radius nor ~coverage_width specified, one of these is required")

        self.coverage_resolution = rospy.get_param("~coverage_resolution", 0.05)  # How big is a cell [m]

        # How much covered is a cell after it has been covered for 1 time step
        self.coverage_effectivity = rospy.get_param("~coverage_effectivity", 5)

        self.map_frame = rospy.get_param("~map_frame", "map")
        self.coverage_frame = rospy.get_param("~coverage_frame", "base_link")

        self.coverage_radius_meters += 2 * self.coverage_resolution  # Compensate for discretization
        self.coverage_radius_cells = int(self.coverage_radius_meters / self.coverage_resolution)

        grid = OccupancyGrid()
        grid.info.resolution = self.coverage_resolution

        grid.info.width = abs(int(x / self.coverage_resolution))
        grid.info.height = abs(int(y / self.coverage_resolution))

        # A non-positive target area size puts the grid origin at that (negative) coordinate
        grid.info.origin.position.x = 0 if x > 0 else x
        grid.info.origin.position.y = 0 if y > 0 else y
        grid.info.origin.orientation.w = 1

        # Initialize OccupancyGrid to have all cells DIRTY.
        # OccupancyGrid.data is an int8[] field, so keep integer cells:
        # float cells cannot be packed as int8 when publishing on Python 3.
        grid.data = self.DIRTY * ones(grid.info.width * grid.info.height, dtype='int8')

        return grid

    def _mark_covered(self, x_meters, y_meters):
        """Lower the value of every grid cell inside the coverage disk around a map position.

        Cells outside the grid are skipped; cell values are clamped at 0.

        :param x_meters: x coordinate of the disk centre in the map frame [m]
        :param y_meters: y coordinate of the disk centre in the map frame [m]
        :return: (x, y) cell coordinates of the disk centre; may lie outside the grid
        """
        grid = self.grid
        width = grid.info.width
        height = grid.info.height
        data = grid.data

        # Use floor rather than int(): int() truncates toward zero, which would map a position
        # just outside the grid on the negative side onto cell 0 instead of cell -1
        x_point = int(math.floor((x_meters - grid.info.origin.position.x) / self.coverage_resolution))
        y_point = int(math.floor((y_meters - grid.info.origin.position.y) / self.coverage_resolution))

        radius = self.coverage_radius_cells
        radius_squared = radius ** 2

        # Offsets of exactly +/-radius can never satisfy the strict circle test below,
        # so iterating over -(radius - 1) .. (radius - 1) visits every cell of the disk
        for y_offset in range(1 - radius, radius):
            row = y_point + y_offset
            if not 0 <= row < height:
                continue

            for x_offset in range(1 - radius, radius):
                column = x_point + x_offset
                if not 0 <= column < width:
                    continue
                if x_offset ** 2 + y_offset ** 2 >= radius_squared:
                    continue

                array_index = column + width * row
                # Convert to a plain int first so the subtraction cannot overflow the int8 cell type
                data[array_index] = max(0, int(data[array_index]) - self.coverage_effectivity)

        return x_point, y_point

    def _compute_progress(self):
        """Return the ratio of grid cells with a value below DIRTY (0.0 for a grid without cells)."""
        cell_count = self.grid.info.width * self.grid.info.height
        if cell_count == 0:
            # Nothing to cover; avoid dividing by zero
            return 0.0
        return float(sum(self.grid.data < self.DIRTY)) / cell_count

    def _update_callback(self, event):
        """Timer callback: mark the cells under the coverage disk and publish progress and grid.

        :param event: timer event passed in by rospy.Timer (unused)
        """
        # Get the position of point (0,0,0) the coverage_disk frame wrt. the map frame
        # (which can both be remapped if need be)
        try:
            (coveragepos, rot) = self.listener.lookupTransform(self.map_frame, self.coverage_frame, rospy.Time(0))
        except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException) as ex:
            # Best effort: skip this update and try again on the next timer tick
            rospy.logdebug("No transform from %s to %s available: %s", self.coverage_frame, self.map_frame, ex)
            return

        # Initialize message
        self.grid.header = Header()
        self.grid.header.frame_id = self.map_frame

        x_point, y_point = self._mark_covered(coveragepos[X], coveragepos[Y])

        rospy.logdebug("x_point %i y_point %i, x_meas %f, y_meas %f", x_point, y_point, coveragepos[X],
                       coveragepos[Y])

        self.progress_pub.publish(self._compute_progress())
        self.grid_pub.publish(self.grid)

    def finish_callback(self, msg):
        """Publish the current coverage progress when msg is truthy.

        :param msg: value that triggers the publication when it evaluates to True
        """
        if msg:
            self.progress_pub.publish(self._compute_progress())

    def reset(self, srv_request):
        """Service handler for 'reset': replace the grid by a fresh one with all cells DIRTY.

        :param srv_request: Trigger request (carries no data, unused)
        :return: (success, message) tuple for the Trigger response
        """
        rospy.loginfo("Reset coverage progress and grid")
        self.grid = self._initialize_map()
        return (True, "Reset coverage progress and grid")

def main():
    """Initialise the 'coverage_progress' ROS node, create the CoverageProgressNode and spin."""
    rospy.init_node('coverage_progress')
    try:
        # Hold on to the node object for as long as we are spinning
        node = CoverageProgressNode()
        rospy.spin()
    except rospy.ROSInterruptException:
        # Interrupted by ROS; exit quietly
        pass


if __name__ == '__main__':
    main()
